import anndata as ad
import hdf5plugin
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
import pandas as pd
import math
import logging
import os
import scanpy as sc
import wandb
import pickle
import argparse
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
from tqdm import tqdm
from copy import deepcopy
from dance.utils import set_seed
from scipy.sparse import csr_matrix
from CellBert.utils.eval import downstream_eval
from CellBert.utils.data import XDict, stratified_sample_genes_by_sparsity
from CellBert.utils.mask import InputDropoutMaskBuilder
from CellBert.model import OmicsFormer
from gears import PertData
from gears.data_utils import get_dropout_non_zero_genes, rank_genes_groups_by_cov
# wandb.login()

class CosineAnnealingWarmupRestarts(_LRScheduler):
    """Cosine-annealing learning-rate schedule with linear warmup and restarts.

    Within a cycle the learning rate climbs linearly from ``min_lr`` to the
    cycle's peak during the first ``warmup_steps`` steps and then follows a
    half cosine back down to ``min_lr``. When stepping sequentially, each new
    cycle stretches its non-warmup part by ``cycle_mult``; the peak of cycle
    ``c`` is ``max_lr * gamma ** c``.

    Args:
        optimizer: Wrapped optimizer whose param-group ``lr`` values are driven.
        first_cycle_steps: Length of the first cycle, in steps.
        cycle_mult: Growth factor applied to the cycle length after a restart.
        max_lr: Peak learning rate of the first cycle.
        min_lr: Floor learning rate; also the value set right after construction.
        warmup_steps: Number of linear warmup steps at the start of each cycle.
        gamma: Per-cycle decay factor applied to the peak learning rate.
        last_epoch: Index of the last epoch, forwarded to the base scheduler.
    """

    def __init__(self, optimizer, first_cycle_steps, cycle_mult=1., max_lr=0.1, min_lr=0.001,
                 warmup_steps=0, gamma=1., last_epoch=-1):
        assert warmup_steps < first_cycle_steps

        # Fixed schedule configuration.
        self.first_cycle_steps = first_cycle_steps
        self.cycle_mult = cycle_mult
        self.base_max_lr = max_lr
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.gamma = gamma

        # Mutable position inside the schedule. These must exist before the
        # base-class constructor runs, because it already calls ``step()``.
        self.max_lr = max_lr
        self.cur_cycle_steps = first_cycle_steps
        self.cycle = 0
        self.step_in_cycle = last_epoch

        super().__init__(optimizer, last_epoch)
        # Start every param group from the floor learning rate.
        self.init_lr()

    def init_lr(self):
        """Reset every param group (and ``base_lrs``) to ``min_lr``."""
        groups = self.optimizer.param_groups
        for group in groups:
            group['lr'] = self.min_lr
        self.base_lrs = [self.min_lr for _ in groups]

    def get_lr(self):
        """Return one learning rate per param group for the current position."""
        pos = self.step_in_cycle
        if pos == -1:
            return self.base_lrs

        if pos < self.warmup_steps:
            # Linear ramp from the base rate towards the cycle's peak.
            return [
                (self.max_lr - floor_lr) * pos / self.warmup_steps + floor_lr
                for floor_lr in self.base_lrs
            ]

        # Half-cosine decay over the post-warmup part of the cycle.
        decay_len = self.cur_cycle_steps - self.warmup_steps
        cos_term = 1 + math.cos(math.pi * (pos - self.warmup_steps) / decay_len)
        return [
            floor_lr + (self.max_lr - floor_lr) * cos_term / 2
            for floor_lr in self.base_lrs
        ]

    def _advance_one_step(self):
        """Move one step forward, starting a new cycle when the current one ends."""
        self.step_in_cycle = self.step_in_cycle + 1
        if self.step_in_cycle >= self.cur_cycle_steps:
            self.cycle += 1
            self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
            decay_len = self.cur_cycle_steps - self.warmup_steps
            self.cur_cycle_steps = int(decay_len * self.cycle_mult) + self.warmup_steps

    def _seek(self, epoch):
        """Recompute cycle index and in-cycle position for an explicit ``epoch``."""
        first = self.first_cycle_steps
        mult = self.cycle_mult
        if epoch >= first:
            if mult == 1.:
                # All cycles have the same length.
                self.cycle, self.step_in_cycle = divmod(epoch, first)
            else:
                # Cycle lengths form a geometric series; invert its partial sum.
                n = int(math.log((epoch / first * (mult - 1) + 1), mult))
                self.cycle = n
                self.step_in_cycle = epoch - int(first * (mult ** n - 1) / (mult - 1))
                self.cur_cycle_steps = first * mult ** (n)
        else:
            self.cur_cycle_steps = first
            self.step_in_cycle = epoch

    def step(self, epoch=None):
        """Advance the schedule (or jump to ``epoch``) and write the new rates."""
        if epoch is None:
            epoch = self.last_epoch + 1
            self._advance_one_step()
        else:
            self._seek(epoch)

        self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        new_lrs = self.get_lr()
        for group, lr in zip(self.optimizer.param_groups, new_lrs):
            group['lr'] = lr
        
def main(task=None, config=None):
    """Fine-tune a pretrained OmicsFormer for perturbation prediction and evaluate it.

    The function reads module-level state prepared by the ``__main__`` section
    of this script: ``args``, ``pretrained_gene_list``, ``batch_labels``,
    ``coord``, ``top_de_dict``, ``data_genes``, ``batches``, ``input_gene_list``
    and the per-perturbation ``train_*`` / ``valid_*`` / ``test_*`` lists.

    Args:
        task: Head type / downstream task name. When ``None`` it is read from
            ``config['head_type']``.
        config: Model and optimisation settings. When ``None`` the run is
            treated as a wandb sweep trial and ``wandb.config`` is used.
    """
    # No explicit config means we were launched by a wandb sweep agent.
    tune_flag = config is None
    wandb.init(group=f"pert_{args.dataset}_{args.pre_model}")

    if tune_flag:
        config = wandb.config
    if task is None:
        task = config['head_type']

    wandb.config = config

    config["gene_list"] = pretrained_gene_list
    config["batch_num"] = batch_labels.max() + 1
    device = torch.device('cuda')

    model = OmicsFormer(**config)
    pretrained_file = f'{args.pre_model}.pt'
    pretrained_model_dict = torch.load(pretrained_file)
    # Drop the "module." prefix, but only from keys that actually carry it;
    # blindly slicing 7 characters would corrupt keys of an unwrapped checkpoint.
    prefix = 'module.'
    pretrained_model_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in pretrained_model_dict.items()
    }
    try:
        model.load_state_dict(pretrained_model_dict)
    except RuntimeError:
        # load_state_dict reports key/shape mismatches as RuntimeError.
        # Fall back to loading only params that are in the model and match the size.
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_model_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        for k, v in pretrained_dict.items():
            print(f"Loading params {k} with shape {v.shape}")
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    model.latent.layers[1].mean = pretrained_model_dict['latent.layers.1.mean'].requires_grad_(True)
    model.latent.layers[1].std = pretrained_model_dict['latent.layers.1.std'].requires_grad_(True)

    print(model)
    model.to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=config['wd'])
    scheduler = CosineAnnealingWarmupRestarts(
        optim,
        first_cycle_steps=15,
        cycle_mult=2,
        max_lr=config['lr'],
        min_lr=1e-7,
        warmup_steps=5,
        gamma=0.9
    )

    def build_inputs(x_seq, cell_idx, label):
        """Bundle one perturbation's tensors into the XDict passed to the model."""
        return XDict({
            'x_seq': x_seq.to(device),
            'batch': torch.tensor(batch_labels[cell_idx]).to(device),
            'coord': coord[cell_idx].to(device),
            'label': label.to(device),
        })

    def mean_or(values, default):
        """Average ``values``; return ``default`` when there is nothing to average."""
        return sum(values) / len(values) if values else default

    # Short metric name used in wandb keys -> key in the downstream_eval result.
    metric_keys = {
        'rmse': 'all_rmse',
        'corr': 'all_corr',
        'cos': 'all_cos',
        'de_rmse': 'top_de_rmse',
        'de_corr': 'top_de_corr',
        'de_cos': 'top_de_cos',
    }

    train_loss = []
    valid_loss = []
    valid_metric = []
    # Always have a checkpoint to restore, even if no epoch ever improves
    # (zero epochs, or a validation metric that is NaN throughout).
    best_state = deepcopy(model.state_dict())
    for epoch in tqdm(range(config['epochs'])):
        epoch_loss = []
        model.train()
        for i in range(len(train_list)):
            if len(train_list[i]) > 0:
                x_dict = build_inputs(train_list[i], train_batch_list[i], train_label_list[i])
                out_dict, loss = model(x_dict, input_gene_list[i])
                optim.zero_grad()
                loss.backward()
                optim.step()
                epoch_loss.append(loss.item())
        train_loss.append(mean_or(epoch_loss, float('nan')))
        scheduler.step()

        valid_scores = None
        test_scores = None
        with torch.no_grad():
            model.eval()
            epoch_loss = []
            valid_epoch = []
            for i in range(len(valid_list)):
                # Score a split only when this perturbation has cells in it;
                # otherwise out_dict/x_dict would be stale results from another
                # split and would contaminate the validation metric.
                if len(valid_list[i]) > 0:
                    x_dict = build_inputs(valid_list[i], valid_batch_list[i], valid_label_list[i])
                    out_dict, loss = model(x_dict, input_gene_list[i])
                    epoch_loss.append(loss.item())
                    valid_scores = downstream_eval(task, out_dict['pred'], x_dict['label'], top_de_dict=top_de_dict)
                    valid_epoch.append(valid_scores['all_rmse'])

                if len(test_list[i]) > 0:
                    x_dict = build_inputs(test_list[i], test_batch_list[i], test_label_list[i])
                    out_dict, loss = model(x_dict, input_gene_list[i])
                    test_scores = downstream_eval(task, out_dict['pred'], x_dict['label'], top_de_dict=top_de_dict)
        valid_loss.append(mean_or(epoch_loss, float('nan')))
        # With no validation data every epoch ties at +inf, so the latest weights win.
        valid_metric.append(mean_or(valid_epoch, float('inf')))

        # The per-epoch report uses the scores of the last non-empty valid/test perturbation.
        if task == 'perturbation_prediction' and valid_scores is not None and test_scores is not None:
            print(f'Epoch {epoch} | Train loss: {train_loss[-1]:.4f} | Valid loss: {valid_loss[-1]:.4f}')
            print(f'Valid RMSE: {valid_scores["all_rmse"]:.4f} | Test RMSE: {test_scores["all_rmse"]:.4f}')
            print(f'Valid Corr: {valid_scores["all_corr"]:.4f} | Test Corr: {test_scores["all_corr"]:.4f}')
            print(f'Valid DE RMSE: {valid_scores["top_de_rmse"]:.4f} | Test DE RMSE: {test_scores["top_de_rmse"]:.4f}')
            print(f'Valid DE Corr: {valid_scores["top_de_corr"]:.4f} | Test DE Corr: {test_scores["top_de_corr"]:.4f}')

            epoch_log = {"train": train_loss[-1], "valid": valid_loss[-1]}
            for short, key in metric_keys.items():
                epoch_log[f"valid_{short}"] = valid_scores[key]
                epoch_log[f"test_{short}"] = test_scores[key]
            wandb.log(epoch_log)

        # Keep the weights of the epoch with the lowest mean validation RMSE.
        if min(valid_metric) == valid_metric[-1]:
            best_state = deepcopy(model.state_dict())

    # Inference with the best checkpoint.
    model.load_state_dict(best_state)
    pred = []
    label = []
    order = []
    control = []
    test_batches = []
    model.eval()
    with torch.no_grad():
        for i in range(len(test_list)):
            if len(test_list[i]) > 0:
                x_dict = build_inputs(test_list[i], test_batch_list[i], test_label_list[i])
                out_dict, loss = model(x_dict, input_gene_list[i])
                # Strip the appended perturbation columns to recover the control expression.
                control.append(test_list[i].to_dense()[:, :-(len(input_gene_list[i]) - len(data_genes))])
                order.append(test_batch_list[i])
                label.append(test_label_list[i])
                pred.append(out_dict['pred'].cpu())
                test_batches.append(batches[i])
        # Release model outputs and the model itself before the final evaluation.
        loss = out_dict = x_dict = None
        del model
    torch.cuda.empty_cache()

    if task == 'perturbation_prediction':
        # NOTE(review): `order` concatenates per-perturbation cell indices, while
        # the tensors indexed here are concatenated across perturbations —
        # confirm this re-indexing is intended.
        pred = torch.cat(pred)[np.concatenate(order)]
        label = torch.cat(label)[np.concatenate(order)]
        control = torch.cat(control)[np.concatenate(order)]
        test_batches = np.concatenate(test_batches)[np.concatenate(order)]
        scores = downstream_eval(task, pred, label, top_de_dict=top_de_dict)
        scores_batch = downstream_eval(task, pred, label, top_de_dict=top_de_dict, batch_labels=test_batches)
        scores_delta = downstream_eval(task, pred, label, top_de_dict=top_de_dict, batch_labels=test_batches, control_level=control)
        print(f'scores: {scores}')
        print(f'scores_batch: {scores_batch}')
        print(f'scores_delta: {scores_delta}')
        if tune_flag:
            final_log = {}
            for suffix, result in (('', scores), ('_batch', scores_batch), ('_delta', scores_delta)):
                for short, key in metric_keys.items():
                    final_log[f'final_{short}{suffix}'] = result[key]
            wandb.log(final_log)
            wandb.finish()

    del pred, label, best_state
    torch.cuda.empty_cache()

def create_sparse_tensor(x):
    """Convert a SciPy CSR matrix into a float32 sparse torch tensor.

    The CSR buffers (``indptr``, ``indices``, ``data``) are handed to
    ``torch.sparse_csr_tensor``; the result is then converted with
    ``to_sparse()`` and cast with ``float()``.
    """
    n_rows, n_cols = x.shape[0], x.shape[1]
    csr = torch.sparse_csr_tensor(x.indptr, x.indices, x.data, (n_rows, n_cols))
    converted = csr.to_sparse()
    return converted.float()

if __name__ == '__main__':
    # Script entry point: parse options, build the per-perturbation train/val/test
    # tensors that main() reads as module globals, then either launch a wandb
    # sweep (--tune) or run a single fine-tuning job with fixed settings.
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default='perturbation_prediction')
    parser.add_argument("--dataset", type=str, default='adamson') # norman, dixit, adamson
    parser.add_argument("--split", type=str, default='single') # single, combo_seen0, combo_seen1, combo_seen2
    parser.add_argument("--latent_mod", type=str, default='gmvae')
    parser.add_argument("--pre_model", type=str, default='20230513_20M_12M')
                        # ['20230510_50M_12M', '20230506_12M' (36M), '20230510_10M_12M', '20230513_20M_12M']
    parser.add_argument("--batch_feat", action='store_true')
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=500)
    parser.add_argument("--pert_fill", type=float, default=-100)
    parser.add_argument("--tune", action='store_true')
    args = parser.parse_args()
    set_seed(args.seed)
    task = args.task
    torch.set_num_threads(32)

    # NOTE: unpickling can execute arbitrary code; only point --pre_model at
    # config files from a trusted source.
    with open(f"{args.pre_model}.config.pkl", "rb") as openfile:
        config = pickle.load(openfile)
    pretrained_gene_list = config['gene_list']

    # Data Setup
    pert_data_loader = PertData('./data/pert_data')
    pert_data_loader.load(data_path=f'./data/pert_data/{args.dataset}')
    data = pert_data_loader.prepare_split(split=args.split, seed=args.seed)
    rank_genes_groups_by_cov(data, groupby='condition_name', covariate='cell_type', control_group='ctrl_1',
                             n_genes=len(data.var), key_added='rank_genes_groups_cov_all')
    data = get_dropout_non_zero_genes(data)
    data_genes = data.var.gene_name.values
    ensg2idx = {j: i for i, j in enumerate(data.var.index)}
    top_de_dict = data.uns['top_non_dropout_de_20']
    # Replace gene identifiers by their column index in `data`.
    for k, de_genes in top_de_dict.items():
        top_de_dict[k] = np.vectorize(ensg2idx.get)(de_genes)

    control_data = data[data.obs.control == 1]
    pert_data = data[data.obs.control == 0]

    batch_labels = LabelEncoder().fit_transform(pert_data.obs['condition'])
    # Perturbed gene names per cell (the 'ctrl' token removed). Stored in an
    # explicit object array: cells may carry one or two perturbed genes, and
    # np.array() on such ragged input raises on recent NumPy versions.
    pert_tokens = [np.array(str(cond).split('+')) for cond in pert_data.obs['condition']]
    pert = np.empty(len(pert_tokens), dtype=object)
    for cell, tokens in enumerate(pert_tokens):
        pert[cell] = tokens[tokens != 'ctrl']

    # Per-perturbation storage; main() reads these names as module globals.
    train_list, valid_list, test_list = [], [], []
    train_label_list, valid_label_list, test_label_list = [], [], []
    train_batch_list, valid_batch_list, test_batch_list = [], [], []
    input_gene_list = []
    batches = []
    # Value of obs.split -> (inputs, cell indices, labels) lists to fill.
    # Dict order (train, val, test) also fixes the order of the random draws.
    split_targets = {
        'train': (train_list, train_batch_list, train_label_list),
        'val': (valid_list, valid_batch_list, valid_label_list),
        'test': (test_list, test_batch_list, test_label_list),
    }
    # Densify the control cells once; the matrix is the same for every perturbation.
    control_x = control_data.X.toarray()
    for batch in tqdm(range(batch_labels.max() + 1)):
        in_batch = batch_labels == batch
        y = torch.tensor(pert_data[in_batch].X.toarray()).float()
        pert_label = pert[in_batch][0]
        input_gene_list.append(np.concatenate([data_genes, pert_label]))
        # Append one constant-valued column per perturbed gene to the control expression.
        pert_input = args.pert_fill * np.ones([len(control_x), len(pert_label)])
        x = np.hstack([control_x, pert_input])

        batch_splits = pert_data.obs.split.values[in_batch]
        for split_name, (x_store, idx_store, y_store) in split_targets.items():
            cell_idx = np.arange(len(batch_splits))[batch_splits == split_name]
            if len(cell_idx) > 0:
                # Pair every perturbed cell with a randomly drawn control cell.
                # Sampling is without replacement unless there are fewer control
                # cells than perturbed cells, which would otherwise raise.
                ctrl_idx = np.random.choice(len(x), len(cell_idx), replace=len(cell_idx) > len(x))
                x_store.append(create_sparse_tensor(csr_matrix(x[ctrl_idx])))
                idx_store.append(torch.tensor(cell_idx).int())
                y_store.append(y[cell_idx])
            else:
                x_store.append([])
                idx_store.append([])
                y_store.append([])
            if split_name == 'test':
                batches.append(batch_labels[in_batch] if len(cell_idx) > 0 else [])

    if not args.batch_feat:
        batch_labels = torch.zeros(pert_data.shape[0]).int()
    coord = torch.zeros(pert_data.shape[0], 2) - 1

    out_dim = y.shape[1]
    del data, x, control_x

    if args.tune:
        param_dict = {
            "head_type": {'values': [task]},
            "mask_type": {'values': ['hidden']},
            "dec_mod": {'values': ['mlp']},
            "dec_hid": {'values': [128, 64, 256]},
            "dec_layers": {'values': [2, 3, 4]},
            "model_dropout": {'values': [0.3, 0.1, 0, 0.5, 0.7]},
            "mask_node_rate": {'values': [0.7, 0.1, 0.3, 0.5]},
            "mask_feature_rate": {'values': [0.5, 0.1, 0.3, 0.7]},
            "drop_node_rate": {'values': [0]},
            "architecture": {'values': ["OmicsFormer"]},
            "epochs": {'values': [args.epochs]},
            "norm": {'values': ["layernorm"]},
            "wd": {'values': [0, 1e-8]},
            "w_li": {'values': [0]},
            "w_en": {'values': [0]},
            "w_ce": {'values': [0]},
            "out_dim": {'values': [out_dim]},
            "batch_feat": {'values': [args.batch_feat]},
        }

        # Encoder settings tied to each pretrained checkpoint:
        # pre_model -> (enc_hid, enc_layers, num_clusters, lr values)
        pretrained_archs = {
            '20230510_50M_12M': (512, 12, 16, [1e-6]),
            '20230506_12M': (512, 8, 2, [1e-5]),  # 36M
            '20230510_10M_12M': (256, 4, 8, [1e-4]),
            '20230513_20M_12M': (384, 4, 16, [1e-4, 1e-5]),
            '20230515_20M_12M': (384, 4, 16, [1e-4]),
            '20230515_10M_12M': (256, 4, 16, [1e-4]),
        }
        # Unknown checkpoints get no encoder entries, as before.
        if args.pre_model in pretrained_archs:
            enc_hid, enc_layers, num_clusters, lr_values = pretrained_archs[args.pre_model]
            param_dict["enc_mod"] = {'values': ['performer']}
            param_dict["latent_mod"] = {'values': ['gmvae']}
            param_dict["enc_hid"] = {'values': [enc_hid]}
            param_dict["enc_layers"] = {'values': [enc_layers]}
            param_dict["post_latent_dim"] = {'values': [64]}
            param_dict["gumbel_softmax"] = {'values': [False]}
            param_dict["num_clusters"] = {'values': [num_clusters]}
            param_dict["lr"] = {'values': lr_values}
        sweep_configuration = {
            'method': 'bayes',
            'name': 'tuning-den',
            'metric': {
                'goal': 'minimize',
                'name': 'final_rmse'
            },
            'parameters': param_dict,
        }
        sweep_id = wandb.sweep(sweep=sweep_configuration, project='CellBert')
        print(sweep_id)
        wandb.agent(sweep_id=sweep_id, function=main, count=1000)

        # Example invocations:
        # CUDA_VISIBLE_DEVICES=2 python perturbation_prediction.py --dataset adamson --split single --pert_fill -100 --tune --pre_model 20230515_10M_12M
        # CUDA_VISIBLE_DEVICES=3 python perturbation_prediction.py --dataset dixit --split single --pert_fill -100 --tune --pre_model 20230515_10M_12M --epochs 2000
        # CUDA_VISIBLE_DEVICES=4 python perturbation_prediction.py --dataset norman --pre_model 20230515_10M_12M --tune --pert_fill -100 --split combo_seen0 --epochs 100

        # Resuming earlier sweeps:
        # wandb.agent(sweep_id="u5sdk3u8", function=main, count=1000) # 20230510_10M_12M, adamson
        # wandb.agent(sweep_id="vp4p5fb6", function=main, count=1000) # 20230510_10M_12M, dixit
        # wandb.agent(sweep_id="u5sdk3u8", function=main, count=1000) # 20230510_10M_12M, norman, combo_seen0
        # wandb.agent(sweep_id="u5sdk3u8", function=main, count=1000) # 20230510_10M_12M, norman, combo_seen1
        # wandb.agent(sweep_id="u5sdk3u8", function=main, count=1000) # 20230510_10M_12M, norman, combo_seen2

        # wandb.agent(sweep_id="u5sdk3u8", function=main, count=1000) # 20230510_20M_12M, adamson
        # wandb.agent(sweep_id="vp4p5fb6", function=main, count=1000) # 20230510_20M_12M, dixit
        # wandb.agent(sweep_id="u5sdk3u8", function=main, count=1000) # 20230510_20M_12M, norman, combo_seen0
        # wandb.agent(sweep_id="u5sdk3u8", function=main, count=1000) # 20230510_20M_12M, norman, combo_seen1
        # wandb.agent(sweep_id="u5sdk3u8", function=main, count=1000) # 20230510_20M_12M, norman, combo_seen2

    else:
        # Single run: fixed fine-tuning settings layered over the pretrained config.
        config.update({
            'mask_type': 'hidden',
            'dec_mod': 'mlp',
            'dec_hid': 128,
            'dec_layers': 2,
            'model_dropout': 0.3,
            'mask_node_rate': 0.7,
            'mask_feature_rate': 0.5,
            'drop_node_rate': 0.,
            'epochs': args.epochs,
            'lr': 1e-4,
            'wd': 0,
            'w_li': 0.,
            'w_en': 0.,
            'w_ce': 0.,
            'gumbel_softmax': True,
            'head_type': task,
            'out_dim': out_dim,
        })
        main(task, config)

